# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""the module is used to process images."""

import copy
import itertools as it
import math

import cv2
import numpy as np
from PIL import Image

import mindspore.dataset.vision.py_transforms as PV

from mindspore.dataset.vision.c_transforms import Resize as Resize_
from mindspore.dataset.vision.c_transforms import Normalize as Normalize_

from mindvision.detection.datasets.pipelines.formatting import Collect
from mindvision.detection.datasets.utils.postprocess import (choose_candidate_by_constraints,
                                                             correct_bbox_by_candidates, rand_init,
                                                             proposal_crop_areas,
                                                             modify_annotation_by_proposal_crop_areas,
                                                             prior_box, match)

from mindvision.engine.class_factory import ClassFactory, ModuleType
from mindvision.engine.dataset.builder import build_transforms
from internals.anchor.defaultbox_ssd import ssd_bboxes_encode

@ClassFactory.register(ModuleType.PIPELINE)
class RandomExpand:
    """Randomly place the image onto a larger canvas filled with ``mean``.

    Half of the calls (``np.random.randint(2)`` truthy) return the sample
    untouched; otherwise the canvas side is ``ratio`` times the original and
    boxes (and masks, when present) are shifted by the random paste offset.
    """

    def __init__(self,
                 mean=(0, 0, 0),
                 to_rgb=True,
                 ratio_range=(1, 4)):
        # `mean` is given in BGR; reverse it when the pipeline works in RGB.
        self.mean = mean[::-1] if to_rgb else mean
        self.min_ratio, self.max_ratio = ratio_range

    def __call__(self, results):
        # Skip expansion with probability 1/2.
        if np.random.randint(2):
            return results

        image = results.get("image")
        bboxes = results.get("bboxes")
        height, width, channels = image.shape

        ratio = np.random.uniform(self.min_ratio, self.max_ratio)
        canvas = np.full((int(height * ratio), int(width * ratio), channels),
                         self.mean).astype(image.dtype)

        left = int(np.random.uniform(0, width * ratio - width))
        top = int(np.random.uniform(0, height * ratio - height))
        canvas[top:top + height, left:left + width] = image

        # Translate boxes by the paste offset (in-place, like the canvas paste).
        bboxes += np.tile((left, top), 2)

        results['image'] = canvas
        results['bboxes'] = bboxes

        mask = results.get("mask")
        if mask is not None:
            count, mask_h, mask_w = mask.shape
            big_mask = np.zeros((count, int(mask_h * ratio), int(mask_w * ratio))).astype(mask.dtype)
            big_mask[:, top:top + height, left:left + width] = mask
            results['mask'] = big_mask
        return results


@ClassFactory.register(ModuleType.PIPELINE)
class Resize(Resize_):
    """Resize the image to a fixed (img_width, img_height) and rescale bboxes.

    Args:
        img_width (int): target output width.
        img_height (int): target output height.
    """

    def __init__(self, img_width, img_height):
        self.img_width = img_width
        self.img_height = img_height

    def __call__(self, results):
        img = results.get("image")
        gt_bboxes = results.get("bboxes")

        # BUGFIX: capture the ORIGINAL size before resizing. Previously the
        # shape was read from the already-resized image, so h_scale/w_scale
        # were always exactly 1.0 and the bboxes were never rescaled.
        h, w = img.shape[:2]

        img_data = cv2.resize(
            img, (self.img_width, self.img_height), interpolation=cv2.INTER_LINEAR)

        h_scale = self.img_height / h
        w_scale = self.img_width / w

        scale_factor = np.array(
            [w_scale, h_scale, w_scale, h_scale], dtype=np.float32)

        # (height, width, 1.0) — trailing 1.0 matches the format used by Rescale.
        img_shape = (self.img_height, self.img_width, 1.0)
        img_shape = np.asarray(img_shape, dtype=np.float32)

        gt_bboxes = gt_bboxes * scale_factor

        # Clip x coords into [0, width-1] and y coords into [0, height-1].
        gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1)
        gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1)

        results['image'] = img_data
        results['image_shape'] = img_shape
        results['bboxes'] = gt_bboxes

        return results


def filp_pil_image(img):
    """Return a left-right mirrored copy of the PIL image (name keeps the original 'filp' typo for callers)."""
    return img.transpose(Image.FLIP_LEFT_RIGHT)


@ClassFactory.register(ModuleType.PIPELINE)
class ResizeWithinMultiScales:
    """
    Crop an image randomly with bounding box constraints, then paste it onto
    a fixed-size canvas (optionally flipped) and correct the boxes to match.
    """

    def __init__(self,
                 max_boxes,
                 jitter,
                 max_trial,
                 flip=0.5,
                 rgb=(128, 128, 128),
                 use_constraints=False):
        """Constructor for ResizeWithinMultiScales.

        Args:
            max_boxes (int): keep at most this many ground-truth boxes.
            jitter (float): aspect-ratio jitter for candidate crops.
            max_trial (int): number of crop candidates to try.
            flip (float): probability of a horizontal flip.
            rgb (tuple): background fill color of the padded canvas.
            use_constraints (bool): enable IoU-constrained candidate selection.
        """
        self.max_boxes = max_boxes
        self.jitter = jitter
        self.max_trial = max_trial
        self.flip = flip
        self.rgb = rgb
        self.use_constraints = use_constraints

    def __call__(self, results):
        image = copy.deepcopy(results['image'])
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        box = results['bboxes']
        image_h, image_w = results['image_shape']
        input_h, input_w = results['resize_size']

        np.random.shuffle(box)
        if len(box) > self.max_boxes:
            box = box[:self.max_boxes]
        flip = rand_init() < self.flip

        box_data = np.zeros((self.max_boxes, 4))

        # BUGFIX: pass the configured values through. `use_constraints` and
        # `max_trial` were previously hard-coded to False and 10 here, so the
        # constructor arguments were silently ignored.
        candidates = choose_candidate_by_constraints(use_constraints=self.use_constraints,
                                                     max_trial=self.max_trial,
                                                     input_w=input_w,
                                                     input_h=input_h,
                                                     image_w=image_w,
                                                     image_h=image_h,
                                                     jitter=self.jitter,
                                                     box=box)
        box_data, candidate = correct_bbox_by_candidates(candidates=candidates,
                                                         input_w=input_w,
                                                         input_h=input_h,
                                                         image_w=image_w,
                                                         image_h=image_h,
                                                         flip=flip,
                                                         box=box,
                                                         box_data=box_data,
                                                         allow_outside_center=True)
        dx, dy, nw, nh = candidate
        # interp=10 selects a random interpolation method per sample.
        interp = get_interp_method(interp=10)
        image = image.resize((nw, nh), pil_image_reshape(interp))

        # Place the image on a gray-colored background canvas.
        new_image = Image.new('RGB', (input_w, input_h), tuple(self.rgb))
        new_image.paste(image, (dx, dy))
        image = new_image

        if flip:
            image = filp_pil_image(image)

        image = np.array(image)
        results['image'] = image
        results['bboxes'] = box_data
        results['image_shape'] = np.array(image.shape[:2], np.int32)
        return results


@ClassFactory.register(ModuleType.PIPELINE)
class PilResize:
    """Resize the sample's image with PIL to a fixed (output_h, output_w)."""

    def __init__(self, resize_size, interp=9):
        """Constructor for PilResize.

        Args:
            resize_size (tuple): target (height, width).
            interp (int): interpolation code; 9 = auto (see get_interp_method).
        """
        self.output_h, self.output_w = resize_size
        self.interp = interp

    def __call__(self, results):
        """Resize results['image'] and record the pre-resize (w, h)."""
        img = results['image']
        if not isinstance(img, Image.Image):
            img = Image.fromarray(img)
        src_w, src_h = img.size

        method = get_interp_method(
            interp=self.interp,
            sizes=(src_h, src_w, self.output_h, self.output_w)
        )
        img = img.resize(
            (self.output_w, self.output_h),
            pil_image_reshape(method)
        )

        results['image'] = np.array(img)
        results['image_shape'] = np.array([src_w, src_h], np.int32)
        return results


@ClassFactory.register(ModuleType.PIPELINE)
class Rescale:
    """Keep-ratio rescale of image, bboxes and masks, then pad to a fixed size.

    The image is rescaled to fit inside (img_width, img_height) preserving
    aspect ratio, boxes/masks follow the same factor, and everything is
    zero-padded (bottom/right) to exactly (img_height, img_width).
    """

    def __init__(self, img_width, img_height):
        # Fixed padded output size.
        self.img_width = img_width
        self.img_height = img_height

    def __call__(self, results):
        img = results.get("image")
        gt_bboxes = results.get("bboxes")

        # First keep-ratio pass bounded by (img_width, img_height).
        img_data, scale_factor = rescale_with_scale(img, (self.img_width, self.img_height))

        # If the height still overshoots, rescale again bounded by height only,
        # and accumulate the combined scale factor.
        if img_data.shape[0] > self.img_height:
            img_data, scale_factor2 = rescale_with_scale(img_data, (self.img_height, self.img_height))
            scale_factor = scale_factor * scale_factor2

        # Scale boxes by the same factor and clip into the rescaled image.
        gt_bboxes = gt_bboxes * scale_factor
        gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_data.shape[1] - 1)
        gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_data.shape[0] - 1)

        pad_h = self.img_height - img_data.shape[0]
        pad_w = self.img_width - img_data.shape[1]
        assert ((pad_h >= 0) and (pad_w >= 0))

        # Zero-pad bottom/right to the fixed output size.
        pad_img_data = np.zeros((self.img_height, self.img_width, 3)).astype(img_data.dtype)
        pad_img_data[0:img_data.shape[0], 0:img_data.shape[1], :] = img_data

        if results.get("mask") is not None:
            gt_mask = results.get("mask")
            # Nearest-neighbor rescale per mask keeps label values intact.
            gt_mask_data = np.array([
                rescale_with_factor(mask, scale_factor)
                for mask in gt_mask
            ])
            mask_count, mask_h, mask_w = gt_mask_data.shape
            pad_mask = np.zeros((mask_count, self.img_height, self.img_width)).astype(gt_mask_data.dtype)
            pad_mask[:, 0:mask_h, 0:mask_w] = gt_mask_data
            results['mask'] = pad_mask

        # (height, width, 1.0) — trailing 1.0 matches Resize's image_shape format.
        img_shape = (self.img_height, self.img_width, 1.0)
        img_shape = np.asarray(img_shape, dtype=np.float32)

        results['image'] = pad_img_data
        results['image_shape'] = img_shape
        results['bboxes'] = gt_bboxes

        return results


@ClassFactory.register(ModuleType.PIPELINE)
class RescaleWithoutGT:
    """Keep-ratio rescale + pad for the inference path (no ground-truth boxes).

    Appends the applied scale factor (twice) to results['image_shape'] so
    postprocessing can map predictions back to the original resolution.
    """

    def __init__(self, img_width, img_height):
        # Fixed padded output size.
        self.img_width = img_width
        self.img_height = img_height

    def __call__(self, results):
        image = results.get("image")
        shape_info = results.get("image_shape")

        image, scale = rescale_with_scale(image, (self.img_width, self.img_height))
        # A second pass caps the height when the keep-ratio rescale overshoots it.
        if image.shape[0] > self.img_height:
            image, extra = rescale_with_scale(image, (self.img_height, self.img_height))
            scale = scale * extra

        pad_h = self.img_height - image.shape[0]
        pad_w = self.img_width - image.shape[1]
        assert pad_h >= 0 and pad_w >= 0

        padded = np.zeros((self.img_height, self.img_width, 3)).astype(image.dtype)
        padded[0:image.shape[0], 0:image.shape[1], :] = image

        shape_info = np.append(shape_info, (scale, scale))
        shape_info = np.asarray(shape_info, dtype=np.float32)

        results['image'] = padded
        results['image_shape'] = shape_info
        return results


@ClassFactory.register(ModuleType.PIPELINE)
class _Normalize(Normalize_):
    """Imnormalize operation for image.

    Subtracts the per-channel mean and divides by the per-channel std,
    optionally converting the input from BGR to RGB first.

    Args:
        mean (tuple): per-channel mean of the input image. Default: ImageNet BGR means.
        std (tuple): per-channel std. Default: ImageNet stds.
        to_rgb (bool): convert BGR to RGB before normalizing. Default: True.
    """

    def __init__(self,
                 mean=None,
                 std=None,
                 to_rgb=True):
        if mean is None:
            mean = (123.675, 116.28, 103.53)
        if std is None:
            std = (58.395, 57.12, 57.375)
        self.mean = np.array(mean)
        self.std = np.array(std)
        self.to_rgb = to_rgb

    def __call__(self, results):
        img = results['image']
        img_data = img.copy().astype(np.float32)
        # BUGFIX: honor the `to_rgb` flag. The conversion previously ran
        # unconditionally, so `to_rgb=False` was silently ignored.
        if self.to_rgb:
            cv2.cvtColor(img_data, cv2.COLOR_BGR2RGB, img_data)  # inplace
        cv2.subtract(img_data, np.float64(self.mean.reshape(1, -1)), img_data)  # inplace
        cv2.multiply(img_data, 1 / np.float64(self.std.reshape(1, -1)), img_data)  # inplace
        results['image'] = img_data
        return results


@ClassFactory.register(ModuleType.PIPELINE)
class StaticNormalize:
    """Scale a uint8 RGB image into [0, 1] and optionally standardize it.

    Args:
        statistic_norm (bool): apply (x - mean) / std after the /255 scaling.
        mean (tuple): per-channel mean in [0, 1] units (ImageNet defaults).
        std (tuple): per-channel std in [0, 1] units (ImageNet defaults).
    """

    def __init__(self,
                 statistic_norm=True,
                 mean=(0.485, 0.456, 0.406),
                 std=(0.229, 0.224, 0.225)):
        """Constructor for StaticNormalize."""
        self.statistic_norm = statistic_norm
        self.mean = mean
        self.std = std

    def __call__(self, results):
        # The image is expected in RGB channel order.
        picture = results['image']
        if isinstance(picture, Image.Image):
            picture = np.array(picture)
        picture = picture / 255.
        if self.statistic_norm:
            picture = (picture - np.array(self.mean)) / np.array(self.std)
        results['image'] = picture.astype(np.float32)
        return results


@ClassFactory.register(ModuleType.PIPELINE)
class RandomFlip:
    """Horizontally flip image, bboxes and masks with probability flip_ratio."""

    def __init__(self, flip_ratio):
        # Probability of applying the flip.
        self.flip_ratio = flip_ratio

    def __call__(self, results):
        # One uniform draw decides whether this sample is flipped.
        if np.random.rand() >= self.flip_ratio:
            return results

        image = results.get("image")
        bboxes = results.get("bboxes")

        # Mirror the image along its width axis.
        flipped_img = np.flip(image, axis=1)
        width = flipped_img.shape[1]

        # Mirror box x-coordinates: new_x0 = w - old_x2 - 1, new_x2 = w - old_x0 - 1.
        new_boxes = bboxes.copy()
        new_boxes[..., 0::4] = width - bboxes[..., 2::4] - 1
        new_boxes[..., 2::4] = width - bboxes[..., 0::4] - 1

        mask = results.get("mask")
        if mask is not None:
            results['mask'] = np.array([m[:, ::-1] for m in mask])

        results['image'] = flipped_img
        results['bboxes'] = new_boxes
        return results


@ClassFactory.register(ModuleType.PIPELINE)
class RandomPilFlip:
    """Use PIL APIs to do flip operations.

    Args:
        ratio (float): probability of flipping. Default: 0.5.
        direction (str): only "FLIP_LEFT_RIGHT" and "FLIP_TOP_BOTTOM" are
            supported; anything else falls back to "FLIP_TOP_BOTTOM".
    """

    def __init__(self, ratio=0.5, direction="FLIP_LEFT_RIGHT"):
        """Constructor for RandomPilFlip."""
        if direction == "FLIP_LEFT_RIGHT":
            self.flip = Image.FLIP_LEFT_RIGHT
        else:
            self.flip = Image.FLIP_TOP_BOTTOM
        self.ratio = ratio

    def __call__(self, results):
        if rand_init() > self.ratio:
            return results
        image = results['image']
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        # BUGFIX: PIL's transpose() returns a new image (it is NOT in-place),
        # so the result must be kept — previously it was discarded and the
        # sample was never flipped. Also use the configured direction instead
        # of a hard-coded FLIP_LEFT_RIGHT.
        image = image.transpose(self.flip)
        results['image'] = np.array(image)
        return results


@ClassFactory.register(ModuleType.PIPELINE)
class ColorDistortion:
    """Randomly jitter hue, saturation and value of an RGB image in HSV space.

    Args:
        hue (float): max absolute hue shift, in normalized [0, 1] hue units.
        saturation (float): max saturation gain; the gain or its inverse is
            drawn from [1, saturation], each with probability 1/2.
        value (float): max value (brightness) gain, used like `saturation`.
    """

    def __init__(self, hue, saturation, value):
        """Constructor for ColorDistortion."""
        self.hue = hue
        self.saturation = saturation
        self.value = value

    def __call__(self, results):
        image = results['image']
        # Draw jitter amounts; inverting the gain half the time makes the
        # distribution symmetric between strengthening and weakening.
        hue = rand_init(-self.hue, self.hue)
        sat = rand_init(1, self.saturation) if rand_init() < .5 else 1 / rand_init(1, self.saturation)
        val = rand_init(1, self.value) if rand_init() < .5 else 1 / rand_init(1, self.value)

        # HSV_FULL maps hue onto the full 0-255 byte range (not OpenCV's 0-179).
        x = cv2.cvtColor(image, cv2.COLOR_RGB2HSV_FULL)
        x = x / 255.
        x[..., 0] += hue
        # Hue wraps around the color circle.
        x[..., 0][x[..., 0] > 1] -= 1
        x[..., 0][x[..., 0] < 0] += 1
        x[..., 1] *= sat
        x[..., 2] *= val
        # Clamp saturation/value overflow before converting back to uint8.
        x[x > 1] = 1
        x[x < 0] = 0
        x = x * 255.
        x = x.astype(np.uint8)
        image = cv2.cvtColor(x, cv2.COLOR_HSV2RGB_FULL)
        results['image'] = image
        return results


@ClassFactory.register(ModuleType.PIPELINE)
class ConvertGrayToColor:
    """Convert a single-channel (H, W) image to 3 channels by replication."""

    def __init__(self):
        pass

    def __call__(self, results):
        """Replicate the gray plane into 3 channels; 3-channel input is untouched."""
        image = results['image']
        if image.ndim == 2:
            results['image'] = np.stack([image] * 3, axis=-1)
        return results


@ClassFactory.register(ModuleType.PIPELINE)
class PerBatchCocoFormat:
    """Unpack a (image, annotation, resize_size, mosaic_flag) tuple into a results dict.

    The annotation rows are [x1, y1, x2, y2, label]; the first four columns
    become 'bboxes' and the fifth becomes 'labels'.
    """

    def __init__(self):
        pass

    def __call__(self, data_tuple):
        image = data_tuple[0]
        anno = data_tuple[1]
        return {
            'image': image,
            'annotation': anno,
            'image_shape': np.array(image.shape[:2], np.int32),
            'bboxes': anno[:, :4],
            'labels': anno[:, 4],
            'resize_size': data_tuple[2],
            'mosaic_flag': data_tuple[3],
        }


def intersect(box_a, box_b):
    """Intersection area between each box in box_a (N, 4) and one box_b (4,)."""
    lo = np.maximum(box_a[:, :2], box_b[:2])
    hi = np.minimum(box_a[:, 2:4], box_b[2:4])
    side = np.clip(hi - lo, a_min=0, a_max=np.inf)
    return side[:, 0] * side[:, 1]


def jaccard_numpy(box_a, box_b):
    """IoU (Jaccard overlap) of each box in box_a (N, 4) against one box_b (4,)."""
    inter = intersect(box_a, box_b)
    area_a = (box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    return inter / (area_a + area_b - inter)


def _rand(a=0., b=1.):
    """Generate random."""
    return np.random.rand() * (b - a) + a


def random_sample_crop(image, boxes):
    """Random Crop the image and boxes.

    With probability 1/6 (min_iou is None) the input is returned unchanged;
    otherwise up to 50 random crops are tried and the first one whose box
    overlaps satisfy the sampled IoU window and which contains at least one
    box center is returned, with the surviving boxes clipped and translated
    into crop coordinates.

    NOTE(review): `rect` is built as [top, left, top+h, left+w] and compared
    against boxes[:, :2]/boxes[:, 2:4], which implies the boxes are stored as
    [y1, x1, y2, x2] here — confirm against the caller.
    """
    height, width, _ = image.shape
    min_iou = np.random.choice([None, 0.1, 0.3, 0.5, 0.7, 0.9])

    if min_iou is None:
        # No IoU constraint sampled: keep the original sample.
        return image, boxes

    # max trails (50)
    for _ in range(50):
        image_t = image

        w = _rand(0.3, 1.0) * width
        h = _rand(0.3, 1.0) * height

        # aspect ratio constraint b/t .5 & 2
        if h / w < 0.5 or h / w > 2:
            continue

        left = _rand() * (width - w)
        top = _rand() * (height - h)

        rect = np.array([int(top), int(left), int(top + h), int(left + w)])
        overlap = jaccard_numpy(boxes, rect)

        # dropout some boxes
        drop_mask = overlap > 0
        if not drop_mask.any():
            continue

        # Reject when overlaps fall entirely outside [min_iou, min_iou + 0.2].
        if overlap[drop_mask].min() < min_iou and overlap[drop_mask].max() > (min_iou + 0.2):
            continue

        image_t = image_t[rect[0]:rect[2], rect[1]:rect[3], :]

        centers = (boxes[:, :2] + boxes[:, 2:4]) / 2.0

        # Keep only boxes whose center lies strictly inside the crop.
        m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1])
        m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1])

        # mask in that both m1 and m2 are true
        mask = m1 * m2 * drop_mask

        # have any valid boxes? try again if not
        if not mask.any():
            continue

        # take only matching gt boxes
        boxes_t = boxes[mask, :].copy()

        # Clamp surviving boxes to the crop and shift into crop coordinates.
        boxes_t[:, :2] = np.maximum(boxes_t[:, :2], rect[:2])
        boxes_t[:, :2] -= rect[:2]
        boxes_t[:, 2:4] = np.minimum(boxes_t[:, 2:4], rect[2:4])
        boxes_t[:, 2:4] -= rect[:2]

        return image_t, boxes_t
    # All 50 trials failed: fall back to the original sample.
    return image, boxes


@ClassFactory.register(ModuleType.PIPELINE)
class Ssdpreprocess:

    """Preprocess pipeline for SSD training samples.

    Randomly crops the image, resizes it to a fixed size, randomly flips it,
    normalizes the boxes into [0, 1] and encodes them against default boxes.

    Args:
        img_height (int): output image height. Default: 640 (the previously
            hard-coded value, so existing callers are unaffected).
        img_width (int): output image width. Default: 640.
    """

    def __init__(self, img_height=640, img_width=640):
        self.img_height = img_height
        self.img_width = img_width

    def __call__(self, image, img_id, image_shape, box):
        """Return (image, encoded_box, label, num_match) for one sample."""
        # Keep OpenCV from oversubscribing the dataset worker processes.
        cv2.setNumThreads(2)

        box = box.astype(np.float32)
        image, box = random_sample_crop(image, box)
        ih, iw, _ = image.shape

        image = cv2.resize(image, (self.img_width, self.img_height))

        flip = _rand() < .5
        if flip:
            image = cv2.flip(image, 1, dst=None)

        # When the channels of image is 1, replicate it to 3 channels.
        if len(image.shape) == 2:
            image = np.expand_dims(image, axis=-1)
            image = np.concatenate([image, image, image], axis=-1)

        # Normalize boxes into [0, 1] by the pre-resize size. Columns 0/2 are
        # divided by the height, which matches random_sample_crop's
        # [y1, x1, y2, x2] convention — TODO confirm.
        box[:, [0, 2]] = box[:, [0, 2]] / ih
        box[:, [1, 3]] = box[:, [1, 3]] / iw

        if flip:
            # Mirror the x-coordinates (columns 1 and 3) after normalization.
            box[:, [1, 3]] = 1 - box[:, [3, 1]]

        box, label, num_match = ssd_bboxes_encode(box)

        return image, box, label, num_match


@ClassFactory.register(ModuleType.PIPELINE)
class YoloBboxPreprocess:
    """Data bbox preprocess for yolo.

    Builds the three per-scale target tensors (y_true) plus fixed-size padded
    ground-truth boxes from the raw bboxes/labels in the results dict.

    Args:
        anchors: flat sequence of (w, h) anchor sizes, 3 per detection layer.
        anchor_mask: per-layer lists of anchor indices.
        num_classes (int): number of object classes.
        label_smooth (bool): enable label smoothing on the class targets.
        label_smooth_factor (float): smoothing strength.
        iou_threshold (float): anchors whose IoU with a box exceeds this are
            also assigned (in addition to the argmax anchor).
        max_boxes (int): fixed number of gt boxes per layer (static shape).
    """

    def __init__(self,
                 anchors,
                 anchor_mask,
                 num_classes,
                 label_smooth,
                 label_smooth_factor,
                 iou_threshold,
                 max_boxes):
        """Constructor for YoloBboxPreprocess."""
        self.anchors = anchors
        self.num_classes = num_classes
        self.label_smooth = label_smooth
        self.label_smooth_factor = label_smooth_factor
        self.iou_threshold = iou_threshold
        self.max_boxes = max_boxes
        self.anchor_mask = anchor_mask

    def __call__(self, results):
        """Fill results with bbox1-3 (y_true tensors) and gt_box1-3 (padded boxes)."""
        anchors = np.array(self.anchors)
        num_layers = anchors.shape[0] // 3
        bboxes = results['bboxes'].tolist()
        labels = results['labels'].tolist()
        # Pad labels with class 0 when there are fewer labels than boxes.
        extend_labels = [0] * (len(bboxes) - len(labels))
        labels = labels + extend_labels
        true_boxes = []
        for bbox, label in zip(bboxes, labels):
            tmp = []
            tmp.extend(bbox)
            tmp.append(int(label))
            # tmp [x_min y_min x_max y_max, label]
            true_boxes.append(tmp)
        true_boxes = np.array(true_boxes, dtype='float32')
        input_shape = np.array(results['image_shape'], dtype='int32')
        boxes_xy = (true_boxes[..., 0:2] + true_boxes[..., 2:4]) // 2.
        # trans to box center point
        boxes_wh = true_boxes[..., 2:4] - true_boxes[..., 0:2]

        # Normalize centers and sizes by (width, height).
        true_boxes[..., 0:2] = boxes_xy / input_shape[::-1]
        true_boxes[..., 2:4] = boxes_wh / input_shape[::-1]

        grid_shapes = [input_shape // 32, input_shape // 16, input_shape // 8]
        y_true = [np.zeros((grid_shapes[layer][0],
                            grid_shapes[layer][1],
                            len(self.anchor_mask[layer]),
                            5 + self.num_classes), dtype='float32') for layer in range(num_layers)]
        # Center anchors and boxes at the origin to compare shapes by IoU.
        anchors = np.expand_dims(anchors, 0)
        anchors_max = anchors / 2.
        anchors_min = -anchors_max
        valid_mask = boxes_wh[..., 0] > 0

        wh = boxes_wh[valid_mask]
        if wh.size != 0:
            wh = np.expand_dims(wh, -2)
            boxes_max = wh / 2.
            boxes_min = -boxes_max

            intersect_min = np.maximum(boxes_min, anchors_min)
            intersect_max = np.minimum(boxes_max, anchors_max)
            intersect_wh = np.maximum(intersect_max - intersect_min, 0.)
            intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
            box_area = wh[..., 0] * wh[..., 1]
            anchor_area = anchors[..., 0] * anchors[..., 1]
            iou = intersect_area / (box_area + anchor_area - intersect_area)

            y_true = self.get_best_anchor_boxes(iou, num_layers, grid_shapes, true_boxes, y_true)

        pad_gt_box0, pad_gt_box1, pad_gt_box2 = self.pad_gt_boxes(y_true)
        results['bbox1'] = y_true[0]
        results['bbox2'] = y_true[1]
        results['bbox3'] = y_true[2]
        results['gt_box1'] = pad_gt_box0
        results['gt_box2'] = pad_gt_box1
        results['gt_box3'] = pad_gt_box2
        return results

    def _assign_anchor(self, y_true, layer, t, n, true_boxes, grid_shapes):
        """Write gt box *t* into y_true[layer] at the cell/slot of anchor *n*.

        Extracted from get_best_anchor_boxes, where this 14-line body was
        duplicated verbatim in both assignment loops.
        """
        # Column (x) and row (y) cell indices from the normalized box center.
        i = np.floor(true_boxes[t, 0] * grid_shapes[layer][1]).astype('int32')
        j = np.floor(true_boxes[t, 1] * grid_shapes[layer][0]).astype('int32')

        k = self.anchor_mask[layer].index(n)
        c = true_boxes[t, 4].astype('int32')
        y_true[layer][j, i, k, 0:4] = true_boxes[t, 0:4]
        y_true[layer][j, i, k, 4] = 1.

        if self.label_smooth:
            sigma = self.label_smooth_factor / (self.num_classes - 1)
            y_true[layer][j, i, k, 5:] = sigma
            y_true[layer][j, i, k, 5 + c] = 1 - self.label_smooth_factor
        else:
            y_true[layer][j, i, k, 5 + c] = 1.

    def get_best_anchor_boxes(self, iou, num_layers, grid_shapes, true_boxes, y_true):
        """Get best anchor boxes.

        First assigns each box to its argmax-IoU anchor, then additionally to
        every anchor whose IoU exceeds self.iou_threshold.
        """
        best_anchor = np.argmax(iou, axis=-1)
        for t, n in enumerate(best_anchor):
            for layer in range(num_layers):
                if n in self.anchor_mask[layer]:
                    self._assign_anchor(y_true, layer, t, n, true_boxes, grid_shapes)

        threshold_anchor = (iou > self.iou_threshold)
        for t in range(threshold_anchor.shape[0]):
            for n in range(threshold_anchor.shape[1]):
                if not threshold_anchor[t][n]:
                    continue
                for layer in range(num_layers):
                    if n not in self.anchor_mask[layer]:
                        continue
                    self._assign_anchor(y_true, layer, t, n, true_boxes, grid_shapes)
        return y_true

    def pad_gt_boxes(self, y_true):
        """Pad ground truth bboxes per layer to max_boxes (avoids dynamic shape)."""
        padded = []
        for layer_true in y_true:
            # Objectness channel selects the cells that actually hold a box.
            mask = np.reshape(layer_true[..., 4:5], [-1])
            gt_box = np.reshape(layer_true[..., 0:4], [-1, 4])
            gt_box = gt_box[mask == 1]

            pad = np.zeros(shape=[self.max_boxes, 4], dtype=np.float32)
            if gt_box.shape[0] < self.max_boxes:
                pad[:gt_box.shape[0]] = gt_box
            else:
                pad = gt_box[:self.max_boxes]
            padded.append(pad)
        return padded


@ClassFactory.register(ModuleType.PIPELINE)
class PerBatchMap:
    """Batch data preprocess map.

    Runs the configured per-sample pipeline over a whole batch with a single
    randomly-chosen resize scale, then collects the results into tuples.

    Args:
        out_orders: output key order for Collect.
        multi_scales: candidate resize sizes; one is sampled per batch.
        output_type_dict: output dtypes for Collect.
        pipeline: preprocess op configs; when None the map is a no-op.
    """

    def __init__(self, out_orders, multi_scales=None, output_type_dict=None,
                 pipeline=None):
        """Constructor for PerBatchMap."""
        self.out_orders = out_orders
        self.multi_scales = multi_scales
        self.output_type_dict = output_type_dict
        self.preprocess_pipeline = pipeline
        if self.preprocess_pipeline is not None:
            self.opts = build_transforms(self.preprocess_pipeline)

    def __call__(self, imgs, anno, x1, x2, batch_info):
        """Preprocess pipeline for one batch of images."""
        if self.preprocess_pipeline is None:
            return None
        return self.data_augment(imgs, anno, x1, x2, batch_info)

    def tuple_batch(self, result_dict):
        """Transform the merged batch dict into ordered output tuples."""
        return Collect(self.out_orders, self.output_type_dict)(result_dict)

    def data_augment(self, *args):
        """Apply the per-sample ops to every sample with one shared scale."""
        if self.preprocess_pipeline is None:
            return args

        batch_images, batch_annos, batch_flags = args[0], args[1], args[3]
        merged = {}
        # One resize scale is drawn per batch so all samples share a shape.
        scale = np.random.choice(self.multi_scales)

        for img, anno, flag in zip(batch_images, batch_annos, batch_flags):
            sample = (img, anno, scale, flag)
            for op in self.opts:
                sample = op(sample)
            merge_batch(sample, merged)

        return self.tuple_batch(merged)


def merge_batch(results, result_dict):
    """Append each value from *results* onto the per-key lists in *result_dict*."""
    for key in results:
        result_dict.setdefault(key, []).append(results[key])


def pil_image_reshape(interp):
    """Translate an integer interpolation code (0-4) to a PIL resampling filter.

    Code 3 maps to NEAREST because PIL has no area-based filter
    (see get_interp_method).
    """
    filters = {
        0: Image.NEAREST,
        1: Image.BILINEAR,
        2: Image.BICUBIC,
        3: Image.NEAREST,
        4: Image.LANCZOS,
    }
    return filters[interp]


def get_interp_method(interp, sizes=()):
    """
    Get the interpolation method for resize functions.

    Wraps a random interp selection and an auto-estimation mode: shrinking
    generally looks best with area-based filters, enlarging with bicubic or
    bilinear.

    Args:
        interp (int): Interpolation method for all resizing operations.

            - 0: Nearest Neighbors interpolation.
            - 1: Bilinear interpolation.
            - 2: Bicubic interpolation over 4x4 pixel neighborhood.
            - 3: Nearest Neighbors (stands in for Area-based, which is not
              available; similar to NN when zooming).
            - 4: Lanczos interpolation over 8x8 pixel neighborhood.
            - 9: auto — cubic for enlarge, nearest for shrink, bilinear otherwise.
            - 10: random selection among codes 0-3.

        sizes (tuple): (old_height, old_width, new_height, new_width); when
            empty, auto (9) falls back to bicubic (2). Default: ().

    Returns:
        int, interp method from 0 to 4.
    """
    if interp == 9:
        if not sizes:
            return 2
        assert len(sizes) == 4
        oh, ow, nh, nw = sizes
        if nh > oh and nw > ow:
            # Strictly enlarging both dimensions: bicubic.
            return 2
        # Strictly shrinking both: nearest; mixed: bilinear.
        return 0 if (nh < oh and nw < ow) else 1
    if interp == 10:
        return np.random.randint(0, 4)
    if interp in (0, 1, 2, 3, 4):
        return interp
    raise ValueError('Unknown interp method %d.' % interp)


def rescale_with_scale(img, scale):
    """Keep-ratio rescale so the image fits inside the `scale` bounds.

    Returns:
        tuple: (rescaled_img, scale_factor applied to both dimensions).
    """
    h, w = img.shape[:2]
    # Bound the long side by max(scale) and the short side by min(scale).
    factor = min(max(scale) / max(h, w), min(scale) / min(h, w))
    target = (int(w * float(factor) + 0.5), int(h * float(factor) + 0.5))
    return cv2.resize(img, target, interpolation=cv2.INTER_LINEAR), factor


def rescale_with_factor(img, scale_factor):
    """Resize ``img`` by ``scale_factor`` using nearest-neighbour interpolation."""
    height, width = img.shape[:2]
    # +0.5 rounds to the nearest integer size.
    target = (int(width * float(scale_factor) + 0.5), int(height * float(scale_factor) + 0.5))
    return cv2.resize(img, target, interpolation=cv2.INTER_NEAREST)


@ClassFactory.register(ModuleType.PIPELINE)
class RandomCrop:
    """Randomly crop an image patch under a sampled IoU constraint.

    SSD-style sampling: pick a minimum-IoU threshold at random (or None to
    keep the whole image), then try up to 50 candidate windows until one
    overlaps the ground-truth boxes consistently with that threshold.
    """

    def __init__(self):
        pass

    def intersect(self, box_a, box_b):
        """Return the intersection area of each box in ``box_a`` with ``box_b``."""
        lower = np.maximum(box_a[:, :2], box_b[:2])
        upper = np.minimum(box_a[:, 2:4], box_b[2:4])
        extent = np.clip(upper - lower, a_min=0, a_max=np.inf)
        return extent[:, 0] * extent[:, 1]

    def jaccard_numpy(self, box_a, box_b):
        """Return the IoU of each box in ``box_a`` (N, 4+) with a single ``box_b``."""
        inter = self.intersect(box_a, box_b)
        area_a = (box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])
        area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
        return inter / (area_a + area_b - inter)

    def __call__(self, img_id, image, boxes):
        height, width, _ = image.shape
        # None means "keep the whole image"; otherwise sample a min-IoU target.
        min_iou = np.random.choice([None, 0.1, 0.3, 0.5, 0.7, 0.9])
        if min_iou is None:
            return img_id, image, boxes

        # Try at most 50 candidate windows before giving up.
        for _ in range(50):
            crop_w = rand_init(0.3, 1.0) * width
            crop_h = rand_init(0.3, 1.0) * height

            # Keep the window aspect ratio between 0.5 and 2.
            if crop_h / crop_w < 0.5 or crop_h / crop_w > 2:
                continue

            left = rand_init() * (width - crop_w)
            top = rand_init() * (height - crop_h)
            rect = np.array([int(left), int(top), int(left + crop_w), int(top + crop_h)])

            overlap = self.jaccard_numpy(boxes, rect)
            touched = overlap > 0
            if not touched.any():
                continue
            # Reject windows whose overlaps straddle the sampled IoU band.
            if overlap[touched].min() < min_iou and overlap[touched].max() > (min_iou + 0.2):
                continue

            patch = image[rect[1]:rect[3], rect[0]:rect[2], :]

            # Keep only boxes whose centre falls strictly inside the window.
            centers = (boxes[:, :2] + boxes[:, 2:4]) / 2.0
            inside = ((rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1]) *
                      (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1]) * touched)
            if not inside.any():
                continue

            kept = boxes[inside, :].copy()
            # Clip the surviving boxes to the window, then shift to window coords.
            kept[:, :2] = np.maximum(kept[:, :2], rect[:2])
            kept[:, :2] -= rect[:2]
            kept[:, 2:4] = np.minimum(kept[:, 2:4], rect[2:4])
            kept[:, 2:4] -= rect[:2]

            return img_id, patch, kept
        return img_id, image, boxes


@ClassFactory.register(ModuleType.PIPELINE)
class ImgResize:
    """Resize an image to a fixed size and normalize its boxes to [0, 1]."""

    def __init__(self, img_width, img_height):
        self.img_width = img_width
        self.img_height = img_height

    def __call__(self, img_id, image, boxes):
        orig_h, orig_w, _ = image.shape
        resized = cv2.resize(image, (self.img_width, self.img_height))

        # Record the original size; the trailing 1 keeps a 3-element shape.
        img_shape = np.asarray((orig_h, orig_w, 1), dtype=np.float32)
        # Normalize box coordinates by the ORIGINAL image size.
        boxes[:, [0, 2]] = boxes[:, [0, 2]] / orig_w
        boxes[:, [1, 3]] = boxes[:, [1, 3]] / orig_h
        return img_id, resized, img_shape, boxes


@ClassFactory.register(ModuleType.PIPELINE)
class _RandomFlip:
    """Horizontally flip an image and its boxes with probability ``flip_ratio``."""

    def __init__(self, flip_ratio, is_normalized=False):
        self.flip_ratio = flip_ratio
        # When True, box coordinates are assumed to be in [0, 1], not pixels.
        self.is_normalized = is_normalized

    def __call__(self, img_id, image, image_shape, boxes):
        if np.random.rand() >= self.flip_ratio:
            # No flip this time.
            return img_id, image, image_shape, boxes

        flipped_img = np.flip(image, axis=1)

        flipped_boxes = boxes.copy()
        _, img_w, _ = flipped_img.shape
        if self.is_normalized:
            # Mirror normalized x-coordinates around 1.
            flipped_boxes[:, [0, 2]] = 1 - flipped_boxes[:, [2, 0]]
        else:
            # Mirror pixel x-coordinates around the image width.
            flipped_boxes[..., 0::4] = img_w - boxes[..., 2::4] - 1
            flipped_boxes[..., 2::4] = img_w - boxes[..., 0::4] - 1

        return img_id, flipped_img, image_shape, flipped_boxes


@ClassFactory.register(ModuleType.PIPELINE)
class BoxEncode:
    """Match ground-truth boxes to default anchors and encode regression targets."""

    def __init__(self, img_shape, steps, anchor_size, feature_size, scales, aspect_ratios, num_default,
                 is_training=True):
        generat_boxes = GeneratDefaultBoxes(img_shape, steps, anchor_size, feature_size, scales, aspect_ratios,
                                            num_default)
        self.default_boxes_ltrb = generat_boxes.default_boxes_ltrb
        self.default_boxes = generat_boxes.default_boxes
        # Per-anchor corner coordinates kept as (N, 1) columns for broadcasting.
        self.x1, self.y1, self.x2, self.y2 = np.split(self.default_boxes_ltrb[:, :4], 4, axis=-1)
        self.vol_anchors = (self.x2 - self.x1) * (self.y2 - self.y1)
        self.matching_threshold = 0.5
        self.num_retinanet_boxes = 67995
        self.prior_scaling = [0.1, 0.2]
        self.is_training = is_training

    def jaccard_with_anchors(self, bbox):
        """Return the IoU between ``bbox`` (ltrb) and every default anchor."""
        inter_w = np.maximum(np.minimum(self.x2, bbox[2]) - np.maximum(self.x1, bbox[0]), 0.)
        inter_h = np.maximum(np.minimum(self.y2, bbox[3]) - np.maximum(self.y1, bbox[1]), 0.)
        inter_vol = inter_h * inter_w
        bbox_vol = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
        return np.squeeze(inter_vol / (self.vol_anchors + bbox_vol - inter_vol))

    def __call__(self, img_id, image, image_shape, boxes):
        pre_scores = np.zeros((self.num_retinanet_boxes), dtype=np.float32)
        t_boxes = np.zeros((self.num_retinanet_boxes, 4), dtype=np.float32)
        t_label = np.zeros((self.num_retinanet_boxes), dtype=np.int64)

        # Greedy matching: a later box claims an anchor only when its IoU wins.
        for bbox in boxes:
            label = int(bbox[4])
            scores = self.jaccard_with_anchors(bbox)
            # Force-match the single best anchor for this box.
            scores[np.argmax(scores)] = 2.0
            mask = (scores > self.matching_threshold) & (scores > pre_scores)
            pre_scores = np.maximum(pre_scores, scores * mask)
            t_label = np.where(mask, label, t_label)
            for axis in range(4):
                t_boxes[:, axis] = np.where(mask, bbox[axis], t_boxes[:, axis])

        index = np.nonzero(t_label)

        # Corner form -> centre form (cx, cy, w, h).
        bboxes = np.zeros((self.num_retinanet_boxes, 4), dtype=np.float32)
        bboxes[:, [0, 1]] = (t_boxes[:, [0, 1]] + t_boxes[:, [2, 3]]) / 2
        bboxes[:, [2, 3]] = t_boxes[:, [2, 3]] - t_boxes[:, [0, 1]]

        # Encode offsets of matched anchors relative to their default boxes.
        matched = bboxes[index]
        defaults = self.default_boxes[index]
        matched[:, :2] = (matched[:, :2] - defaults[:, :2]) / (defaults[:, 2:] * self.prior_scaling[0])
        # Clamp the size ratio away from zero before the log.
        matched[:, 2:4] = np.log(np.maximum(matched[:, 2:4] / defaults[:, 2:4], 0.000001)) / self.prior_scaling[1]
        bboxes[index] = matched

        num_match = np.array([len(index[0])], dtype=np.int32)
        if self.is_training:
            return image, image_shape, bboxes, t_label.astype(np.int32), num_match
        return image, image_shape, self.default_boxes, img_id


class GeneratDefaultBoxes():
    """
    Generate default (anchor) boxes for RetinaNet-style matching.

    `self.default_boxes` is a flat (N, 4) float32 array in centre form
    [cx, cy, w, h]; `self.default_boxes_ltrb` holds the same boxes in corner
    form [x1, y1, x2, y2]. All coordinates are relative to the image size.
    """

    def __init__(self, img_shape, steps, anchor_size, feature_sizes, scales, aspect_ratios, num_default):
        self.img_shape = img_shape
        self.steps = steps
        self.anchor_size = anchor_size
        self.feature_sizes = feature_sizes
        self.aspect_ratios = aspect_ratios
        self.num_default = num_default

        # Number of grid cells per image side for each pyramid level.
        fk = self.img_shape[0] / np.array(self.steps)
        scale_arr = np.array(scales)
        size_arr = np.array(self.anchor_size)

        self.default_boxes = []
        for level, feature_size in enumerate(self.feature_sizes):
            base_size = size_arr[level] / self.img_shape[0]
            # Three scaled sizes per aspect ratio at this pyramid level.
            level_sizes = [base_size * scale_arr[k] for k in range(3)]
            all_sizes = []
            for aspect_ratio in self.aspect_ratios[level]:
                root = math.sqrt(aspect_ratio)
                for size in level_sizes:
                    all_sizes.append((size / root, size * root))  # (h, w)

            assert len(all_sizes) == self.num_default[level]

            for i, j in it.product(range(feature_size), repeat=2):
                cx, cy = (j + 0.5) / fk[level], (i + 0.5) / fk[level]
                for h, w in all_sizes:
                    self.default_boxes.append([cx, cy, w, h])

        def to_ltrb(cx, cy, w, h):
            return cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2

        # Corner form, used for IoU computation.
        self.default_boxes_ltrb = np.array(tuple(to_ltrb(*i) for i in self.default_boxes), dtype='float32')
        self.default_boxes = np.array(self.default_boxes, dtype='float32')


@ClassFactory.register(ModuleType.PIPELINE)
class SsdResizeShape:
    """Resize an image to the SSD input size and record its original shape.

    Args:
        image_size (tuple): Target (height, width). Default: (640, 640), the
            previously hard-coded value, so existing callers are unaffected.
    """

    def __init__(self, image_size=(640, 640)):
        self.image_size = image_size

    def __call__(self, image, img_id, image_shape, box):
        img_h, img_w, _ = image.shape
        input_h, input_w = self.image_size
        image = cv2.resize(image, (input_w, input_h))
        # cv2.resize drops a trailing singleton channel, so an (H, W, 1) input
        # comes back as (H, W); restore three identical channels.
        if len(image.shape) == 2:
            image = np.expand_dims(image, axis=-1)
            image = np.concatenate([image, image, image], axis=-1)

        # Original size, needed later to map predictions back to the source image.
        image_shape = np.array((img_h, img_w), np.float32)

        return image, image_shape, img_id


@ClassFactory.register(ModuleType.PIPELINE)
class CenterfaceResize:
    """Geometric preprocessing (crop/scale/flip/warp) for CenterFace.

    In 'train' mode: the image is randomly cropped (or centre/scale jittered),
    possibly flipped, then warped to a square ``input_res`` input; the affine
    transforms into output (label) space are stored for later target creation.
    In 'test' mode: the image is rescaled by ``scale`` and warped to a fixed
    (``input_h`` x ``input_w``) or 32-aligned input size for inference.
    """

    def __init__(self, scale, split, input_h=832, input_w=832, fix_res=True, rand_crop=True,
                 down_ratio=4, shift=0.1, rotate=0, flip=0.5, output_res=128, input_res=512):
        """Initialize CenterfaceResize."""
        self.input_h = input_h  # eval-time input height (used when fix_res)
        self.input_w = input_w  # eval-time input width (used when fix_res)
        self.scale = scale  # train: scale jitter factor; test: global rescale
        self.fix_res = fix_res  # test: fixed input size vs 32-aligned size
        self.split = split  # 'train' or 'test'
        self.down_ratio = down_ratio  # stride between input and output maps
        self.rand_crop = rand_crop  # train: random crop vs centre/scale jitter
        self.shift = shift  # centre-shift std (used when rand_crop is False)
        self.rotate = rotate  # rotation probability AND magnitude bound, see __call__
        self.flip = flip  # horizontal-flip probability
        self.output_res = output_res  # label-map resolution (train)
        self.input_res = input_res  # network input resolution (train)

    def __call__(self, results):
        """ Do resize. """
        rot = 0
        image = results['image']
        if self.split == 'train':
            flipped = False
            height, width = image.shape[0], image.shape[1]
            # Default crop: centred, covering the longer image side.
            c = np.array([width / 2., height / 2.], dtype=np.float32)
            s = max(height, width) * 1.0
            if self.rand_crop:
                # s = s * np.random.choice(np.arange(0.8, 1.3, 0.05)) # for 768*768 or 800* 800
                s = s * np.random.choice(np.arange(0.6, 1.0, 0.05))  # for 512 * 512
                border = s * np.random.choice([0.1, 0.2, 0.25])
                w_border = self.get_border(border, width)  # w > 2 * w_border
                h_border = self.get_border(border, height)  # h > 2 * h_border
                c[0] = np.random.randint(low=w_border, high=width - w_border)
                c[1] = np.random.randint(low=h_border, high=height - h_border)
            else:
                # Jitter the crop centre and scale instead of cropping randomly.
                sf = self.scale
                cf = self.shift
                c[0] += s * np.clip(np.random.randn() * cf, -2 * cf, 2 * cf)
                c[1] += s * np.clip(np.random.randn() * cf, -2 * cf, 2 * cf)
                s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)
            # NOTE(review): self.rotate is used both as the probability and as
            # the magnitude bound; with the default rotate=0 this branch never
            # triggers — confirm this is intended.
            if np.random.random() < self.rotate:
                rf = self.rotate
                rot = np.clip(np.random.randn() * rf, -rf * 2, rf * 2)
            if np.random.random() < self.flip:  # opt.flip = 0.5
                flipped = True
                image = image[:, ::-1, :]
                c[0] = width - c[0] - 1
            trans_input = get_affine_transform(c, s, rot, [self.input_res, self.input_res])
            inp_image = cv2.warpAffine(
                image, trans_input, (self.input_res, self.input_res),
                flags=cv2.INTER_LINEAR)
            # Transforms into label space, with and without the rotation.
            trans_output_rot = get_affine_transform(c, s, rot, [self.output_res, self.output_res])
            trans_output = get_affine_transform(c, s, 0, [self.output_res, self.output_res])

            results['trans_output_rot'] = trans_output_rot
            results['trans_output'] = trans_output
            results['rot'] = rot
            results['flipped'] = flipped
            results['width'] = width
            results['image'] = inp_image

        elif self.split == 'test':
            height, width = results['image_shape']
            new_height = int(height * self.scale)
            new_width = int(width * self.scale)
            if self.fix_res:  # True
                inp_height, inp_width = self.input_h, self.input_w
                c = np.array([new_width / 2., new_height / 2.], dtype=np.float32)
                s = max(height, width) * 1.0
            else:
                # Round the input size up to a multiple of 32 (network stride).
                inp_height = int(np.ceil(new_height / 32) * 32)
                inp_width = int(np.ceil(new_width / 32) * 32)
                c = np.array([new_width // 2, new_height // 2], dtype=np.float32)
                s = np.array([inp_width, inp_height], dtype=np.float32)

            # Centre/scale and output resolution are needed later for decoding.
            results['c'] = np.array(c)
            results['s'] = np.array(s)
            results['out_height'] = np.array(inp_height // self.down_ratio)
            results['out_width'] = np.array(inp_width // self.down_ratio)
            trans_input = get_affine_transform(c, s, rot, [inp_width, inp_height])
            resized_image = cv2.resize(image, (new_width, new_height))
            inp_image = cv2.warpAffine(
                resized_image, trans_input, (inp_width, inp_height),
                flags=cv2.INTER_LINEAR)
            results['image'] = inp_image

        return results

    def get_border(self, border, size):
        """
        Halve ``border`` (integer-dividing by powers of 2) until it fits
        strictly inside ``size``, i.e. until size > 2 * (border // i).
        """
        i = 1
        while size - border // i <= border // i:  # size > 2 * (border // i)
            i *= 2
        return border // i


@ClassFactory.register(ModuleType.PIPELINE)
class ColorAug:
    """Photometric colour augmentation: brightness, contrast, saturation, PCA lighting.

    The image is modified in place, so it must have a float dtype for the
    in-place scaling to work (the original header said "Use for np.uint8";
    in-place float multiplication of a uint8 array would fail — TODO confirm
    the expected input dtype with callers).

    Args:
        eig_val (list): Eigenvalues of the dataset colour covariance (lighting).
        eig_vec (list): Matching eigenvectors.
        scale (float): Stored for interface compatibility; not used below.
    """

    def __init__(self, eig_val, eig_vec, scale):
        self.eig_val = np.array(eig_val)
        self.eig_vec = np.array(eig_vec)
        self.scale = scale

    def __call__(self, results):
        image = results['image']
        # NOTE(review): the RNG is re-seeded with 123 on every call, so the
        # "random" augmentation is identical for every image — confirm this is
        # intentional before changing it.
        data_rng = np.random.RandomState(123)
        self.color_aug(data_rng, image, self.eig_val, self.eig_vec)
        results["image"] = image
        return results

    def color_aug(self, data_rng, image, eig_val, eig_vec):
        """Apply the three blend-based ops in random order, then PCA lighting."""
        functions = [self.brightness_, self.contrast_, self.saturation_]
        # NOTE(review): the shuffle uses the global RNG, not data_rng.
        np.random.shuffle(functions)

        gs = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        gs_mean = gs.mean()
        for f in functions:
            f(data_rng, image, gs, gs_mean, 0.4)
        self.lighting_(data_rng, image, 0.1, eig_val, eig_vec)

    def lighting_(self, data_rng, image, alphastd, eigval, eigvec):
        """Add PCA colour noise: image += eigvec @ (eigval * alpha)."""
        alpha = data_rng.normal(scale=alphastd, size=(3,))
        image += np.dot(eigvec, eigval * alpha)

    def blend_(self, alpha, image1, image2):
        """In-place blend: image1 = alpha * image1 + (1 - alpha) * image2.

        NOTE(review): image2 is also scaled in place; saturation_ passes a view
        of ``gs``, so ``gs`` is mutated across ops — confirm before changing.
        """
        image1 *= alpha
        image2 *= (1 - alpha)
        image1 += image2

    def saturation_(self, data_rng, image, gs, gs_mean, var):
        """Blend the image towards its grayscale version by a random factor."""
        alpha = 1. + data_rng.uniform(low=-var, high=var)
        self.blend_(alpha, image, gs[:, :, None])

    def brightness_(self, data_rng, image, gs, gs_mean, var):
        """Scale the image by a random factor in [1 - var, 1 + var]."""
        alpha = 1. + data_rng.uniform(low=-var, high=var)
        image *= alpha

    def contrast_(self, data_rng, image, gs, gs_mean, var):
        """Blend the image towards the grayscale mean by a random factor."""
        alpha = 1. + data_rng.uniform(low=-var, high=var)
        self.blend_(alpha, image, gs_mean)


@ClassFactory.register(ModuleType.PIPELINE)
class CenterfaceCropPreprocess:
    """
    Anchor-guided random rescale for CenterFace training.

    Picks one ground-truth face at random, finds the anchor closest to its
    side length, then rescales the whole image (with its boxes and keypoints)
    so that the face roughly matches a randomly chosen anchor no larger than
    that one, capped so the image stays within ``max_size`` x ``max_size``.
    """

    def __init__(self, max_size, inf_distance, anchor_idx, anchors):
        self.max_size = max_size  # upper bound on the rescaled image side
        self.inf_distance = inf_distance  # "infinity" seed for the nearest-anchor search
        self.anchor_idx = anchor_idx  # fallback anchor index
        self.anchors = anchors  # candidate face side lengths, in pixels

    def __call__(self, results):
        anns = results['annotation']
        image = results['image']
        cv2.setNumThreads(0)
        # Convert COCO [x, y, w, h] boxes to corner form [x1, y1, x2, y2].
        boxes = []
        for ann in anns:
            boxes.append([ann['bbox'][0], ann['bbox'][1], ann['bbox'][0] \
                          + ann['bbox'][2], ann['bbox'][1] + ann['bbox'][3]])
        boxes = np.asarray(boxes, dtype=np.float32)

        height, width = image.shape[0], image.shape[1]

        # Side length (sqrt of area) of one randomly picked face.
        box_area = (boxes[:, 2] - boxes[:, 0] + 1) * (boxes[:, 3] - boxes[:, 1] + 1)
        rand_idx = np.random.randint(0, len(box_area))
        rand_side = box_area[rand_idx] ** 0.5

        # Linear scan for the anchor closest to the picked face's side length.
        distance = self.inf_distance
        anchor_idx = self.anchor_idx
        for i, anchor in enumerate(self.anchors):
            if abs(anchor - rand_side) < distance:
                distance = abs(anchor - rand_side)
                anchor_idx = i

        # Target a random anchor no larger than the matched one (the 11 caps
        # the candidate count), with an extra 0.5x-2x jitter.
        target_anchor = np.random.choice(self.anchors[0:min(anchor_idx + 1, 11)])
        ratio = float(target_anchor) / rand_side
        ratio = ratio * (2 ** np.random.uniform(-1, 1))

        # Cap the ratio so the rescaled image fits in max_size * max_size pixels.
        if int(height * ratio * width * ratio) > self.max_size * self.max_size:
            ratio = (self.max_size * self.max_size / (height * width)) ** 0.5

        # Random interpolation method for extra augmentation diversity.
        interp_methods = [cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_LANCZOS4]
        interp_method = np.random.choice(interp_methods)
        image = cv2.resize(image, None, None, fx=ratio, fy=ratio, interpolation=interp_method)

        # Scale boxes by the same ratio as the image.
        boxes[:, 0] *= ratio
        boxes[:, 1] *= ratio
        boxes[:, 2] *= ratio
        boxes[:, 3] *= ratio

        # Write scaled boxes back in COCO [x, y, w, h] form and scale the
        # 5 facial keypoints (stored as flat [x, y, v] triplets).
        boxes = boxes.tolist()
        for i, _ in enumerate(anns):
            anns[i]['bbox'] = [boxes[i][0], boxes[i][1], boxes[i][2] - boxes[i][0], boxes[i][3] - boxes[i][1]]
            for j in range(5):
                anns[i]['keypoints'][j * 3] *= ratio
                anns[i]['keypoints'][j * 3 + 1] *= ratio
        results['annotation'] = anns
        results['image'] = image
        return results


@ClassFactory.register(ModuleType.PIPELINE)
class CenterfaceBboxPreprocess:
    """ Data bbox preprocess for centerface.

    Converts per-image annotations into dense training targets: a centre
    heatmap, size/offset regression maps with per-location loss weights, and
    box-normalized landmark (keypoint) offsets with their masks.
    """

    def __init__(self, output_res, num_joints, max_objs, num_classes, flip_idx):
        self.output_res = output_res  # side length of the output (label) maps
        self.num_joints = num_joints  # keypoints per face
        self.max_objs = max_objs  # maximum objects considered per image
        self.num_classes = num_classes
        self.flip_idx = flip_idx  # keypoint index pairs swapped on horizontal flip

    def __call__(self, results):
        flipped = results['flipped']
        image = results["image"]
        anns = results["annotation"]
        width = results['width']
        rot = results["rot"]
        trans_output_rot = results["trans_output_rot"]
        trans_output = results["trans_output"]
        num_objs = results['num_objs']
        # Dense target maps, all at output_res x output_res.
        hm = np.zeros((self.num_classes, self.output_res, self.output_res), dtype=np.float32)
        hm_hp = np.zeros((self.num_joints, self.output_res, self.output_res), dtype=np.float32)

        wh = np.zeros((self.output_res, self.output_res, 2), dtype=np.float32)
        reg = np.zeros((self.output_res, self.output_res, 2), dtype=np.float32)
        ind = np.zeros((self.output_res, self.output_res),
                       dtype=np.float32)  # as float32, need no data_type change later

        reg_mask = np.zeros((self.max_objs), dtype=np.uint8)
        wight_mask = np.zeros((self.output_res, self.output_res, 2), dtype=np.float32)

        kps = np.zeros((self.output_res, self.output_res, self.num_joints * 2), dtype=np.float32)
        kps_mask = np.zeros((self.output_res, self.output_res, self.num_joints * 2), dtype=np.float32)
        # Per-keypoint sub-pixel offsets with their flat indices and masks.
        hp_offset = np.zeros((self.max_objs * self.num_joints, 2), dtype=np.float32)
        hp_ind = np.zeros((self.max_objs * self.num_joints), dtype=np.int64)
        hp_mask = np.zeros((self.max_objs * self.num_joints), dtype=np.int64)

        draw_gaussian = draw_umich_gaussian

        gt_det = []
        for k in range(num_objs):
            ann = anns[k]
            bbox = self.coco_box_to_bbox(ann['bbox'])  # [x,y,w,h]--[x1,y1,x2,y2]
            cls_id = 0  # int(ann['category_id']) - 1
            pts = np.array(ann['keypoints'], np.float32).reshape(self.num_joints, 3)  # (x,y,0/1)
            if flipped:
                # Mirror the box and keypoints; swap left/right keypoint pairs.
                bbox[[0, 2]] = width - bbox[[2, 0]] - 1
                pts[:, 0] = width - pts[:, 0] - 1
                for e in self.flip_idx:  # flip_idx = [[0, 1], [3, 4]]
                    pts[e[0]], pts[e[1]] = pts[e[1]].copy(), pts[e[0]].copy()

            # Map the box corners into output (label) space and clip.
            bbox[:2] = affine_transform(bbox[:2], trans_output)  # [0, 1] -- (x1, y1)
            bbox[2:] = affine_transform(bbox[2:], trans_output)  # [2, 3] -- (x2, y2)
            bbox = np.clip(bbox, 0, self.output_res - 1)
            h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
            if (h > 0 and w > 0) or (rot != 0):
                radius = gaussian_radius((math.ceil(h), math.ceil(w)))
                radius = max(0, int(radius))
                ct = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2], dtype=np.float32)
                ct_int = ct.astype(np.int32)

                # Centre-cell targets: valid flag, log-size, sub-pixel offset.
                ind[ct_int[1], ct_int[0]] = 1.0
                wh[ct_int[1], ct_int[0], :] = np.log(1. * w / 4), np.log(1. * h / 4)
                reg[ct_int[1], ct_int[0], :] = ct[0] - ct_int[0], ct[1] - ct_int[1]

                reg_mask[k] = 1.0
                wight_mask[ct_int[1], ct_int[0], 0] = 1
                wight_mask[ct_int[1], ct_int[0], 1] = 1

                # Up-weight small faces, and nearly ignore the tiniest ones.
                # if w*h <= 20: # can get what we want sometime, but unstable
                #     wight_mask[k] = 15
                if w * h <= 40:
                    wight_mask[ct_int[1], ct_int[0], 0] = 5
                    wight_mask[ct_int[1], ct_int[0], 1] = 5
                if w * h <= 20:
                    wight_mask[ct_int[1], ct_int[0], 0] = 10
                    wight_mask[ct_int[1], ct_int[0], 1] = 10
                if w * h <= 10:
                    wight_mask[ct_int[1], ct_int[0], 0] = 15
                    wight_mask[ct_int[1], ct_int[0], 1] = 15
                if w * h <= 4:
                    wight_mask[ct_int[1], ct_int[0], 0] = 0.1
                    wight_mask[ct_int[1], ct_int[0], 1] = 0.1

                num_kpts = pts[:, 2].sum()
                if num_kpts == 0:
                    # No visible keypoints: mark the centre as a near-positive.
                    hm[cls_id, ct_int[1], ct_int[0]] = 0.9999

                hp_radius = gaussian_radius((math.ceil(h), math.ceil(w)))
                hp_radius = max(0, int(hp_radius))
                for j in range(self.num_joints):
                    if pts[j, 2] > 0:
                        pts[j, :2] = affine_transform(pts[j, :2], trans_output_rot)
                        if pts[j, 0] >= 0 and pts[j, 0] < self.output_res and pts[j, 1] >= 0 and pts[j, 1] \
                                < self.output_res:
                            # Landmark offsets from the box centre, normalized by box size.
                            kps[ct_int[1], ct_int[0], j * 2: j * 2 + 2] = pts[j, :2] - ct_int
                            kps[ct_int[1], ct_int[0], j * 2: j * 2 + 1] = kps[ct_int[1], ct_int[0], \
                                                                          j * 2: j * 2 + 1] / w
                            kps[ct_int[1], ct_int[0], j * 2 + 1: j * 2 + 2] = kps[ct_int[1], ct_int[0], \
                                                                              j * 2 + 1: j * 2 + 2] / h
                            kps_mask[ct_int[1], ct_int[0], j * 2: j * 2 + 2] = 1.0

                            pt_int = pts[j, :2].astype(np.int32)
                            hp_offset[k * self.num_joints + j] = pts[j, :2] - pt_int
                            hp_ind[k * self.num_joints + j] = pt_int[1] * self.output_res + pt_int[0]
                            hp_mask[k * self.num_joints + j] = 1

                            draw_gaussian(hm_hp[j], pt_int, hp_radius)
                            # Ignore landmarks of very small boxes (area <= 8 px).
                            kps_mask[ct_int[1], ct_int[0], j * 2: j * 2 + 2] = \
                                0.0 if ann['bbox'][2] * ann['bbox'][3] <= 8.0 else 1.0
                draw_gaussian(hm[cls_id], ct_int, radius)
                gt_det.append([ct[0] - w / 2, ct[1] - h / 2,
                               ct[0] + w / 2, ct[1] + h / 2, 1] +
                              pts[:, :2].reshape(self.num_joints * 2).tolist() + [cls_id])

        results['image'] = image
        results['hm'] = hm
        results['reg_mask'] = reg_mask
        results['ind'] = ind
        results['wh'] = wh
        results['wight_mask'] = wight_mask
        results['hm_offset'] = reg
        results['hps_mask'] = kps_mask
        results['landmarks'] = kps

        return results

    def coco_box_to_bbox(self, box):
        """
        Convert a COCO box (x1, y1, w, h) to corner form (x1, y1, x2, y2).
        """
        bbox = np.array([box[0], box[1], box[0] + box[2], box[1] + box[3]], dtype=np.float32)
        return bbox


def affine_transform(pt, t):
    """Apply the 2x3 affine matrix ``t`` to the 2-D point ``pt``; return (x, y)."""
    homogeneous = np.array([pt[0], pt[1], 1.], dtype=np.float32)
    return np.dot(t, homogeneous)[:2]


def get_affine_transform(center, scale, rot, output_size, shift=np.array([0, 0], dtype=np.float32), inv=0):
    """Build the 2x3 affine matrix mapping a (center, scale, rot) crop to ``output_size``.

    Args:
        center (np.ndarray): Crop centre (x, y) in source-image coordinates.
        scale: Crop side length; a bare scalar is expanded to (scale, scale).
        rot (float): Rotation angle in degrees.
        output_size (sequence): Destination (width, height).
        shift (np.ndarray): Relative shift of the crop, in units of ``scale``.
        inv (int): When truthy, return the inverse (dst -> src) transform.

    Returns:
        np.ndarray: A 2x3 affine transform matrix.
    """
    if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
        scale = np.array([scale, scale], dtype=np.float32)

    src_w = scale[0]
    dst_w, dst_h = output_size[0], output_size[1]

    # Rotate the reference direction (0, -src_w/2) by ``rot`` degrees.
    rot_rad = np.pi * rot / 180
    sn, cs = np.sin(rot_rad), np.cos(rot_rad)
    src_point = [0, src_w * -0.5]
    src_dir = [src_point[0] * cs - src_point[1] * sn,
               src_point[0] * sn + src_point[1] * cs]
    dst_dir = np.array([0, dst_w * -0.5], np.float32)

    # Two corresponding points per image; the third is derived below.
    src = np.zeros((3, 2), dtype=np.float32)
    dst = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale * shift
    src[1, :] = center + src_dir + scale * shift
    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5], np.float32) + dst_dir

    # Third point: the (p0 - p1) direction rotated 90 degrees around p1,
    # giving three non-collinear correspondences for getAffineTransform.
    delta_src = src[0, :] - src[1, :]
    src[2:, :] = src[1, :] + np.array([-delta_src[1], delta_src[0]], dtype=np.float32)
    delta_dst = dst[0, :] - dst[1, :]
    dst[2:, :] = dst[1, :] + np.array([-delta_dst[1], delta_dst[0]], dtype=np.float32)

    if inv:
        return cv2.getAffineTransform(np.float32(dst), np.float32(src))
    return cv2.getAffineTransform(np.float32(src), np.float32(dst))


def gaussian_radius(det_size, min_overlap=0.7):
    """Return the largest gaussian radius keeping IoU >= ``min_overlap``.

    Solves three quadratic constraints (the standard three-case derivation
    used by CornerNet-style detectors) and returns the smallest root.
    """
    height, width = det_size

    # Case 1: both box corners shift inward.
    b1 = height + width
    c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
    r1 = (b1 + np.sqrt(b1 ** 2 - 4 * c1)) / 2

    # Case 2: both box corners shift outward.
    b2 = 2 * (height + width)
    c2 = (1 - min_overlap) * width * height
    r2 = (b2 + np.sqrt(b2 ** 2 - 16 * c2)) / 2

    # Case 3: one corner shifts inward, the other outward.
    a3 = 4 * min_overlap
    b3 = -2 * min_overlap * (height + width)
    c3 = (min_overlap - 1) * width * height
    r3 = (b3 + np.sqrt(b3 ** 2 - 4 * a3 * c3)) / 2

    return min(r1, r2, r3)


def gaussian2d(shape, sigma=1):
    """Return an unnormalized 2-D gaussian kernel of the given (odd) shape."""
    half_h, half_w = [(edge - 1.) / 2. for edge in shape]
    ys, xs = np.ogrid[-half_h:half_h + 1, -half_w:half_w + 1]
    kernel = np.exp(-(xs * xs + ys * ys) / (2 * sigma * sigma))
    # Zero out the numerically insignificant tail.
    kernel[kernel < np.finfo(kernel.dtype).eps * kernel.max()] = 0
    return kernel


def draw_umich_gaussian(heatmap, center, radius, k=1):
    """Splat a gaussian peak of the given radius onto ``heatmap`` at ``center``.

    Existing values are kept where they are larger (element-wise maximum), so
    overlapping objects do not erase each other. Returns the same heatmap.
    """
    diameter = 2 * radius + 1
    # Build the kernel inline (equivalent to gaussian2d((d, d), sigma=d / 6)).
    half = (diameter - 1.) / 2.
    ys, xs = np.ogrid[-half:half + 1, -half:half + 1]
    sigma = diameter / 6
    gaussian = np.exp(-(xs * xs + ys * ys) / (2 * sigma * sigma))
    gaussian[gaussian < np.finfo(gaussian.dtype).eps * gaussian.max()] = 0

    x, y = int(center[0]), int(center[1])
    height, width = heatmap.shape[0:2]
    # Clip the kernel to the portion that falls inside the heatmap bounds.
    left, right = min(x, radius), min(width - x, radius + 1)
    top, bottom = min(y, radius), min(height - y, radius + 1)
    masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
    masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right]
    if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:
        np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
    return heatmap

@ClassFactory.register(ModuleType.PIPELINE)
class Concatenate:
    """Pixel-unshuffle: stack the four 2x2 sub-grids of an image along axis 0.

    A (C, H, W) array becomes (4C, H/2, W/2) by sampling even/odd rows and
    columns; every pixel value is preserved.
    """

    def __init__(self):
        """Constructor for Concatenate."""

    def __call__(self, image):
        even_even = image[..., ::2, ::2]
        odd_even = image[..., 1::2, ::2]
        even_odd = image[..., ::2, 1::2]
        odd_odd = image[..., 1::2, 1::2]
        return np.concatenate((even_even, odd_even, even_odd, odd_odd), axis=0)


@ClassFactory.register(ModuleType.PIPELINE)
class _Concatenate:
    """Dict-based pixel-unshuffle: stack the four 2x2 sub-grids of
    ``results['image']`` along axis 0, halving height and width.
    """

    def __init__(self):
        """Constructor for _Concatenate."""

    def __call__(self, results):
        img = results['image']
        parts = (img[..., ::2, ::2], img[..., 1::2, ::2],
                 img[..., ::2, 1::2], img[..., 1::2, 1::2])
        results['image'] = np.concatenate(parts, axis=0)
        return results


def _preprocess_true_boxes(true_boxes, anchors, in_shape, num_classes, max_boxes, label_smooth,
                           label_smooth_factor=0.1, iou_threshold=0.213):
    """
    Introduction
    ------------
        preprocessing ground truth box: encode raw corner-format boxes into
        the three YOLO output grids (strides 32/16/8) and build fixed-size
        padded gt arrays so every sample has static shapes.
    Parameters
    ----------
        true_boxes: ground truth box shape as [boxes, 5], x_min, y_min, x_max, y_max, class_id
        anchors: anchor (w, h) pairs; three anchors per output layer.
        in_shape: network input size as [h, w].
        num_classes: number of categories (width of the one-hot class slot).
        max_boxes: fixed number of gt boxes kept per layer (zero padded).
        label_smooth: whether to smooth the one-hot class targets.
        label_smooth_factor: smoothing strength when label_smooth is True.
        iou_threshold: minimum anchor/box IOU for a top-k anchor assignment.
    Returns
    -------
        y_true[0], y_true[1], y_true[2]: per-layer targets with shape
        [grid_h, grid_w, anchors_per_layer, 5 + num_classes], and
        pad_gt_box0/1/2: per-layer gt boxes padded/truncated to max_boxes rows.
    """
    anchors = np.array(anchors)
    num_layers = anchors.shape[0] // 3
    # Anchor indices per output layer: largest anchors on the coarsest grid.
    anchor_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
    true_boxes = np.array(true_boxes, dtype='float32')
    input_shape = np.array(in_shape, dtype='int32')
    boxes_xy = (true_boxes[..., 0:2] + true_boxes[..., 2:4]) // 2.
    # trans to box center point
    boxes_wh = true_boxes[..., 2:4] - true_boxes[..., 0:2]
    # input_shape is [h, w]; normalize xy/wh to [0, 1] (note [::-1] -> [w, h])
    true_boxes[..., 0:2] = boxes_xy / input_shape[::-1]
    true_boxes[..., 2:4] = boxes_wh / input_shape[::-1]
    # true_boxes = [xywh]
    grid_shapes = [input_shape // 32, input_shape // 16, input_shape // 8]
    # grid_shape [h, w]
    y_true = [np.zeros((grid_shapes[l][0], grid_shapes[l][1], len(anchor_mask[l]),
                        5 + num_classes), dtype='float32') for l in range(num_layers)]
    # y_true [gridy, gridx]
    anchors = np.expand_dims(anchors, 0)
    anchors_max = anchors / 2.
    anchors_min = -anchors_max
    # Only boxes with a positive width are considered valid.
    valid_mask = boxes_wh[..., 0] > 0
    wh = boxes_wh[valid_mask]
    if wh.size != 0:
        wh = np.expand_dims(wh, -2)
        # wh shape[box_num, 1, 2]
        boxes_max = wh / 2.
        boxes_min = -boxes_max
        # IOU between each gt box and every anchor, both centered at origin.
        intersect_min = np.maximum(boxes_min, anchors_min)
        intersect_max = np.minimum(boxes_max, anchors_max)
        intersect_wh = np.maximum(intersect_max - intersect_min, 0.)
        intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
        box_area = wh[..., 0] * wh[..., 1]
        anchor_area = anchors[..., 0] * anchors[..., 1]
        iou = intersect_area / (box_area + anchor_area - intersect_area)

        # topk iou: assign each box to its top-k (k = 4) best-matching
        # anchors, skipping matches below iou_threshold
        topk = 4
        topk_flag = iou.argsort()
        topk_flag = topk_flag >= topk_flag.shape[1] - topk
        flag = topk_flag.nonzero()
        for index in range(len(flag[0])):
            t = flag[0][index]
            n = flag[1][index]
            if iou[t][n] < iou_threshold:
                continue
            for l in range(num_layers):
                if n in anchor_mask[l]:
                    i = np.floor(true_boxes[t, 0] * grid_shapes[l][1]).astype('int32')  # grid_x (col)
                    j = np.floor(true_boxes[t, 1] * grid_shapes[l][0]).astype('int32')  # grid_y (row)

                    k = anchor_mask[l].index(n)
                    c = true_boxes[t, 4].astype('int32')
                    y_true[l][j, i, k, 0:4] = true_boxes[t, 0:4]
                    y_true[l][j, i, k, 4] = 1.

                    # label smoothing: spread label_smooth_factor evenly over
                    # the non-target classes
                    if label_smooth:
                        sigma = label_smooth_factor / (num_classes - 1)
                        y_true[l][j, i, k, 5:] = sigma
                        y_true[l][j, i, k, 5 + c] = 1 - label_smooth_factor
                    else:
                        y_true[l][j, i, k, 5 + c] = 1.
        # best anchor for gt: the single best match is always assigned,
        # regardless of iou_threshold
        best_anchor = np.argmax(iou, axis=-1)
        for t, n in enumerate(best_anchor):
            for l in range(num_layers):
                if n in anchor_mask[l]:
                    i = np.floor(true_boxes[t, 0] * grid_shapes[l][1]).astype('int32')  # grid_x (col)
                    j = np.floor(true_boxes[t, 1] * grid_shapes[l][0]).astype('int32')  # grid_y (row)

                    k = anchor_mask[l].index(n)
                    c = true_boxes[t, 4].astype('int32')
                    y_true[l][j, i, k, 0:4] = true_boxes[t, 0:4]
                    y_true[l][j, i, k, 4] = 1.

                    # label smoothing (same scheme as above)
                    if label_smooth:
                        sigma = label_smooth_factor / (num_classes - 1)
                        y_true[l][j, i, k, 5:] = sigma
                        y_true[l][j, i, k, 5 + c] = 1 - label_smooth_factor
                    else:
                        y_true[l][j, i, k, 5 + c] = 1.

    # pad_gt_boxes for avoiding dynamic shape
    pad_gt_box0 = np.zeros(shape=[max_boxes, 4], dtype=np.float32)
    pad_gt_box1 = np.zeros(shape=[max_boxes, 4], dtype=np.float32)
    pad_gt_box2 = np.zeros(shape=[max_boxes, 4], dtype=np.float32)

    mask0 = np.reshape(y_true[0][..., 4:5], [-1])
    gt_box0 = np.reshape(y_true[0][..., 0:4], [-1, 4])
    # gt_box [boxes, [x,y,w,h]]
    gt_box0 = gt_box0[mask0 == 1]
    # gt_box0: get all boxes which have object
    if gt_box0.shape[0] < max_boxes:
        pad_gt_box0[:gt_box0.shape[0]] = gt_box0
    else:
        pad_gt_box0 = gt_box0[:max_boxes]
    # gt_box0.shape[0]: total number of boxes in gt_box0
    # top N of pad_gt_box0 is real box, and after are pad by zero

    mask1 = np.reshape(y_true[1][..., 4:5], [-1])
    gt_box1 = np.reshape(y_true[1][..., 0:4], [-1, 4])
    gt_box1 = gt_box1[mask1 == 1]
    if gt_box1.shape[0] < max_boxes:
        pad_gt_box1[:gt_box1.shape[0]] = gt_box1
    else:
        pad_gt_box1 = gt_box1[:max_boxes]

    mask2 = np.reshape(y_true[2][..., 4:5], [-1])
    gt_box2 = np.reshape(y_true[2][..., 0:4], [-1, 4])

    gt_box2 = gt_box2[mask2 == 1]
    if gt_box2.shape[0] < max_boxes:
        pad_gt_box2[:gt_box2.shape[0]] = gt_box2
    else:
        pad_gt_box2 = gt_box2[:max_boxes]
    return y_true[0], y_true[1], y_true[2], pad_gt_box0, pad_gt_box1, pad_gt_box2


@ClassFactory.register(ModuleType.PIPELINE)
class PreprocessTrueBox:
    """Build YOLO training targets (three grid tensors plus padded gt boxes)
    from raw annotations.

    Args:
        anchors: anchor (w, h) pairs for all three output scales.
        anchor_mask: anchor index groups per output layer. NOTE(review): the
            underlying encoder hard-codes its own mask, so this value is
            currently stored but not forwarded — confirm intent.
        num_classes: number of object categories.
        label_smooth: whether to apply label smoothing to class targets.
        label_smooth_factor: smoothing strength.
        iou_threshold: minimum anchor/box IOU for a top-k assignment.
        max_boxes: fixed number of gt boxes per image (zero padded).
    """
    def __init__(self,
                 anchors,
                 anchor_mask,
                 num_classes,
                 label_smooth,
                 label_smooth_factor,
                 iou_threshold,
                 max_boxes):
        """Constructor for PreprocessTrueBox."""
        self.anchors = anchors
        self.num_classes = num_classes
        self.label_smooth = label_smooth
        self.label_smooth_factor = label_smooth_factor
        self.iou_threshold = iou_threshold
        self.max_boxes = max_boxes
        self.anchor_mask = anchor_mask

    def __call__(self, anno, input_shape):
        """Encode ``anno`` boxes into per-layer targets for ``input_shape``."""
        # Bug fix: forward the configured iou_threshold; previously it was
        # stored in __init__ but the encoder silently used its own default.
        bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \
            _preprocess_true_boxes(true_boxes=anno, anchors=self.anchors, in_shape=input_shape,
                                   num_classes=self.num_classes, max_boxes=self.max_boxes,
                                   label_smooth=self.label_smooth,
                                   label_smooth_factor=self.label_smooth_factor,
                                   iou_threshold=self.iou_threshold)

        return anno, np.array(bbox_true_1), np.array(bbox_true_2), np.array(bbox_true_3), \
               np.array(gt_box1), np.array(gt_box2), np.array(gt_box3)


def bbox_iou(bbox_a, bbox_b, offset=0):
    """Pairwise Intersection-Over-Union between two sets of corner boxes.

    Parameters
    ----------
    bbox_a : numpy.ndarray
        Boxes with shape :math:`(N, 4+)`, columns x_min, y_min, x_max, y_max.
    bbox_b : numpy.ndarray
        Boxes with shape :math:`(M, 4+)`.
    offset : float or int, default is 0
        Added to widths/heights (e.g. 1 for inclusive pixel coordinates).
        Must be 0 for normalized bboxes in ``[0, 1]``.

    Returns
    -------
    numpy.ndarray
        An :math:`(N, M)` matrix of IOU values.

    Raises
    ------
    IndexError
        If either input has fewer than 4 columns.
    """
    if bbox_a.shape[1] < 4 or bbox_b.shape[1] < 4:
        raise IndexError("Bounding boxes axis 1 must have at least length 4")

    top_left = np.maximum(bbox_a[:, None, :2], bbox_b[:, :2])
    bot_right = np.minimum(bbox_a[:, None, 2:4], bbox_b[:, 2:4])

    # Zero out pairs with no overlap (any top-left coordinate >= bottom-right).
    overlaps = (top_left < bot_right).all(axis=2)
    inter = np.prod(bot_right - top_left + offset, axis=2) * overlaps
    area_a = np.prod(bbox_a[:, 2:4] - bbox_a[:, :2] + offset, axis=1)
    area_b = np.prod(bbox_b[:, 2:4] - bbox_b[:, :2] + offset, axis=1)
    return inter / (area_a[:, None] + area_b - inter)


def _is_iou_satisfied_constraint(min_iou, max_iou, box, crop_box):
    """Return True when every box/crop IOU lies inside [min_iou, max_iou]."""
    iou = bbox_iou(box, crop_box)
    lowest, highest = iou.min(), iou.max()
    return min_iou <= lowest and highest <= max_iou


def _choose_candidate_by_constraints(max_trial, input_w, input_h, image_w, image_h, jitter, box, use_constraints):
    """Randomly propose crop/scale candidates that satisfy IOU constraints.

    Each candidate is a ``(dx, dy, nw, nh)`` tuple: the offset at which an
    image resized to ``nw x nh`` is pasted into the ``input_w x input_h``
    canvas. The identity candidate (no jitter) is always included first.

    Args:
        max_trial (int): number of random proposals tried per constraint.
        input_w (int): target canvas width.
        input_h (int): target canvas height.
        image_w (int): source image width.
        image_h (int): source image height.
        jitter (float): aspect-ratio jitter amplitude.
        box (numpy.ndarray): annotation boxes, shape (N, 5+), pixel corners.
        use_constraints (bool): if True, sample under a ladder of min/max IOU
            constraints; otherwise accept any proposal.

    Returns:
        list: accepted ``(dx, dy, nw, nh)`` candidates.

    Raises:
        ValueError: if ``box`` is empty (at least one annotation is required).
    """
    # Validate once up front instead of re-checking (and raising a generic
    # Exception) inside every trial iteration.
    if box.size <= 0:
        raise ValueError("!!! annotation box is less than 1")
    if use_constraints:
        constraints = (
            (0.1, None),
            (0.3, None),
            (0.5, None),
            (0.7, None),
            (0.9, None),
            (None, 1),
        )
    else:
        constraints = ((None, None),)
    # add default candidate
    candidates = [(0, 0, input_w, input_h)]
    # The crop window is fixed for the whole search; hoist it out of the loops.
    crop_box = np.array((0, 0, input_w, input_h))
    for min_iou, max_iou in constraints:
        min_iou = -np.inf if min_iou is None else min_iou
        max_iou = np.inf if max_iou is None else max_iou

        for _ in range(max_trial):
            # Jittered aspect ratio and scale for this proposal.
            new_ar = float(input_w) / float(input_h) * rand_init(1 - jitter, 1 + jitter) / \
                    rand_init(1 - jitter, 1 + jitter)
            scale = rand_init(0.5, 2)

            if new_ar < 1:
                nh = int(scale * input_h)
                nw = int(nh * new_ar)
            else:
                nw = int(scale * input_w)
                nh = int(nw / new_ar)

            dx = int(rand_init(0, input_w - nw))
            dy = int(rand_init(0, input_h - nh))

            # Map the boxes into the proposed placement and keep the proposal
            # only when all box/crop IOUs satisfy the current constraint.
            t_box = copy.deepcopy(box)
            t_box[:, [0, 2]] = t_box[:, [0, 2]] * float(nw) / float(image_w) + dx
            t_box[:, [1, 3]] = t_box[:, [1, 3]] * float(nh) / float(image_h) + dy

            if _is_iou_satisfied_constraint(min_iou, max_iou, t_box, crop_box[np.newaxis]):
                candidates.append((dx, dy, nw, nh))
    return candidates


def _correct_bbox_by_candidates(candidates, input_w, input_h, image_w,
                                image_h, flip, box, box_data, allow_outside_center, max_boxes):
    """Calculate correct boxes."""
    while candidates:
        if len(candidates) > 1:
            # ignore default candidate which do not crop
            candidate = candidates.pop(np.random.randint(1, len(candidates)))
        else:
            candidate = candidates.pop(np.random.randint(0, len(candidates)))
        dx, dy, nw, nh = candidate
        t_box = copy.deepcopy(box)
        t_box[:, [0, 2]] = t_box[:, [0, 2]] * float(nw) / float(image_w) + dx
        t_box[:, [1, 3]] = t_box[:, [1, 3]] * float(nh) / float(image_h) + dy
        if flip:
            t_box[:, [0, 2]] = input_w - t_box[:, [2, 0]]

        if allow_outside_center:
            pass
        else:
            t_box = t_box[
                np.logical_and((t_box[:, 0] + t_box[:, 2]) / 2. >= 0., (t_box[:, 1] + t_box[:, 3]) / 2. >= 0.)]
            t_box = t_box[np.logical_and((t_box[:, 0] + t_box[:, 2]) / 2. <= input_w,
                                         (t_box[:, 1] + t_box[:, 3]) / 2. <= input_h)]

        # recorrect x, y for case x,y < 0 reset to zero, after dx and dy, some box can smaller than zero
        t_box[:, 0:2][t_box[:, 0:2] < 0] = 0
        # recorrect w,h not higher than input size
        t_box[:, 2][t_box[:, 2] > input_w] = input_w
        t_box[:, 3][t_box[:, 3] > input_h] = input_h
        box_w = t_box[:, 2] - t_box[:, 0]
        box_h = t_box[:, 3] - t_box[:, 1]
        # discard invalid box: w or h smaller than 1 pixel
        t_box = t_box[np.logical_and(box_w > 1, box_h > 1)]

        if t_box.shape[0] > 0:
            # break if number of find t_box
            box_data[: len(t_box)] = t_box
            return box_data, candidate
    return np.zeros(shape=[max_boxes, 5], dtype=np.float64), (0, 0, nw, nh)


def convert_gray_to_color(img):
    """Replicate a single-channel (H, W) image into three channels (H, W, 3).

    Arrays that already carry a channel axis are returned unchanged (same
    object, no copy).
    """
    if img.ndim == 2:
        img = np.repeat(img[:, :, np.newaxis], 3, axis=-1)
    return img


def color_distortion(img, hue, sat, val, device_num):
    """Randomly jitter hue, saturation and value of an RGB image.

    Hue is shifted by a uniform amount in [-hue, hue] with wrap-around;
    saturation and value are each multiplied by a factor drawn from [1, sat]
    (resp. [1, val]) or its reciprocal. Returns a uint8 RGB image.
    """
    hue_shift = rand_init(-hue, hue)
    sat_gain = rand_init(1, sat) if rand_init() < .5 else 1 / rand_init(1, sat)
    val_gain = rand_init(1, val) if rand_init() < .5 else 1 / rand_init(1, val)
    if device_num != 1:
        # Restrict OpenCV threading when multiple workers share the host.
        cv2.setNumThreads(1)
    hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV_FULL) / 255.
    hsv[..., 0] += hue_shift
    # Wrap hue back into [0, 1] (one step suffices for |hue_shift| < 1).
    hsv[..., 0][hsv[..., 0] > 1] -= 1
    hsv[..., 0][hsv[..., 0] < 0] += 1
    hsv[..., 1] *= sat_gain
    hsv[..., 2] *= val_gain
    hsv = np.clip(hsv, 0, 1) * 255.
    return cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2RGB_FULL)


@ClassFactory.register(ModuleType.PIPELINE)
class MultiScaleTrans:
    """Multi scale transform: random jittered resize/placement with flip and
    color distortion, remapping the annotation boxes to match.

    Args:
        max_boxes (int): fixed number of annotation rows kept per image.
        jitter (float): aspect-ratio jitter amplitude for crop proposals.
        max_trial (int): random proposals tried per IOU constraint.
        flip (float): probability of a horizontal flip. Default: 0.5.
        rgb (tuple): background fill color of the padded canvas.
        use_constraints (bool): sample placements under an IOU constraint
            ladder instead of accepting any proposal.
    """

    def __init__(self,
                 max_boxes,
                 jitter,
                 max_trial,
                 flip=0.5,
                 rgb=(128, 128, 128),
                 use_constraints=False):
        self.max_boxes = max_boxes
        self.jitter = jitter
        self.max_trial = max_trial
        self.flip = flip
        self.rgb = rgb
        self.use_constraints = use_constraints

    def __call__(self, image, anno, input_size, mosaic_flag):
        """Return (augmented image, padded box array, image hw shape).

        Args:
            image: encoded image (decoded here when mosaic_flag[0] == 0,
                i.e. when mosaic has not already produced pixels).
            anno: annotation boxes [N, 5] (x_min, y_min, x_max, y_max, class).
            input_size: target (h, w) of the network input.
            mosaic_flag: flag sequence; element 0 selects the decode path.
        """
        if mosaic_flag[0] == 0:
            image = PV.Decode()(image)

        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        image_w, image_h = image.size
        input_h, input_w = input_size

        # Shuffle before truncating so the kept subset of boxes is random.
        np.random.shuffle(anno)
        if len(anno) > self.max_boxes:
            anno = anno[:self.max_boxes]
        flip = rand_init() < self.flip
        box_data = np.zeros((self.max_boxes, 5))

        # Propose placements, then pick one that keeps at least one box and
        # remap the annotations to it.
        candidates = _choose_candidate_by_constraints(use_constraints=self.use_constraints, max_trial=self.max_trial,
                                                      input_w=input_w, input_h=input_h, image_w=image_w,
                                                      image_h=image_h, jitter=self.jitter, box=anno)
        box_data, candidate = _correct_bbox_by_candidates(candidates=candidates, input_w=input_w, input_h=input_h,
                                                          image_w=image_w, image_h=image_h, flip=flip, box=anno,
                                                          box_data=box_data, allow_outside_center=True,
                                                          max_boxes=self.max_boxes)
        dx, dy, nw, nh = candidate
        interp = get_interp_method(interp=10)
        image = image.resize((nw, nh), pil_image_reshape(interp))
        # Paste the resized image onto a solid-color background canvas.
        new_image = Image.new('RGB', (input_w, input_h), self.rgb)
        new_image.paste(image, (dx, dy))
        image = new_image

        if flip:
            image = filp_pil_image(image)

        image = np.array(image)
        image = convert_gray_to_color(image)
        # NOTE(review): hue/sat/val jitter strengths and device_num=8 are
        # hard-coded here — confirm they should not be configurable.
        image = color_distortion(image, 0.015, 1.5, 0.4, 8)

        return image, box_data, np.array(image.shape[0:2])

@ClassFactory.register(ModuleType.PIPELINE)
class RetinafaceBboxPreprocess:
    """Encode retinaface ground truth against precomputed prior boxes.

    When ``match_thresh`` is None the priors are simply attached to the
    results under 'priors'; otherwise boxes/labels/landmarks are matched to
    the priors and the encoded targets are stored under 'truths', 'conf'
    and 'landm'.
    """
    def __init__(self, match_thresh=None, variance=None, image_size=None, anchor=None, step=None, clip=None):
        self.match_thresh = match_thresh
        self.variances = variance
        # Priors depend only on the (square) input size and anchor config,
        # so they are generated once here.
        self.priors = prior_box((image_size, image_size), anchor, step, clip)

    def __call__(self, results):
        if self.match_thresh is None:
            # No matching configured: just expose the raw priors.
            results['priors'] = self.priors
            return results
        loc_t, conf_t, landm_t = match(self.match_thresh, results.get('bboxes'), self.priors,
                                       self.variances, results.get('labels'), results.get('landms'))
        results['truths'] = loc_t
        results['conf'] = conf_t
        results['landm'] = landm_t
        return results

@ClassFactory.register(ModuleType.PIPELINE)
class RandomCropRetinaface:
    """Randomly crop (or pad) an image to a square of ``image_input_size``
    and remap its boxes, labels and landmarks accordingly.

    Args:
        image_input_size (int): side length of the square output image.
    """
    def __init__(self, image_input_size=840):
        self.image_input_size = image_input_size

    def __call__(self, results):
        """Apply the crop augmentation to a result dict.

        Reads 'image', 'bboxes', 'labels' and 'landms'; writes back the
        augmented image plus targets split into 4 bbox columns, 10 landmark
        columns and the trailing label column.
        """
        image = results.get('image').astype(np.float32)
        boxes = results.get('bboxes')
        labels = results.get('labels')
        landms = results.get('landms')
        aug_image, aug_target = self._data_aug(image, boxes, labels, landms, self.image_input_size)
        results['image'] = aug_image
        results['bboxes'] = aug_target[:, :4]
        results['labels'] = aug_target[:, -1]
        results['landms'] = aug_target[:, 4:14]
        return results

    def _data_aug(self, image, boxes, labels, landms, image_input_size, max_trial=250):
        """select candidate regions and modify annotation"""
        image_h, image_w, _ = image.shape
        input_h, input_w = image_input_size, image_input_size
        # 50% chance of horizontal flip.
        flip = rand_init() < .5
        candidates = proposal_crop_areas(max_trial=max_trial,
                                         image_w=image_w,
                                         image_h=image_h,
                                         boxes=boxes)
        targets, candidate = modify_annotation_by_proposal_crop_areas(candidates=candidates,
                                                                      input_w=input_w,
                                                                      input_h=input_h,
                                                                      flip=flip,
                                                                      boxes=boxes,
                                                                      labels=labels,
                                                                      landms=landms,
                                                                      allow_outside_center=False)

        # crop image
        dx, dy, nw, nh = candidate
        image = image[dy:(dy + nh), dx:(dx + nw)]

        if nw != nh:
            # A non-square candidate is only expected to be the whole image
            # (no crop), per the assert below.
            assert nw == image_w and nh == image_h
            # pad ori image to square; fill with fixed per-channel constant
            # (presumably the dataset BGR means — verify against training code)
            size = max(nw, nh)
            t_image = np.empty((size, size, 3), dtype=image.dtype)
            t_image[:, :] = (104, 117, 123)
            t_image[:nh, :nw] = image
            image = t_image

        interp = get_interp_method(interp=10)
        image = cv2.resize(image, (input_w, input_h), interpolation=pil_image_reshape(interp))

        if flip:
            image = image[:, ::-1]

        image = image.astype(np.float32)
        return image, targets


@ClassFactory.register(ModuleType.PIPELINE)
class EqualProportionResize:
    """Resize an image so its short side reaches ``target_size`` (capped so
    the long side never exceeds ``max_size``), then paste it onto a
    ``max_size x max_size`` canvas filled with (104, 117, 123).
    """
    def __init__(self, target_size, max_size):
        self.target_size = target_size
        self.max_size = max_size

    def __call__(self, results):
        img = np.float32(results.get("image"))
        short_side = np.min(img.shape[0:2])
        long_side = np.max(img.shape[0:2])
        resize = float(self.target_size) / float(short_side)
        # Prevent the bigger axis from exceeding max_size.
        if np.round(resize * long_side) > self.max_size:
            resize = float(self.max_size) / float(long_side)

        img = cv2.resize(img, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR)
        assert img.shape[0] <= self.max_size and img.shape[1] <= self.max_size

        # Paste the resized image into the top-left of a mean-filled square.
        canvas = np.empty((self.max_size, self.max_size, 3), dtype=img.dtype)
        canvas[:, :] = (104.0, 117.0, 123.0)
        canvas[0:img.shape[0], 0:img.shape[1]] = img
        img = canvas

        # NOTE(review): scale is computed from the padded canvas, so all four
        # entries equal max_size — confirm downstream decode expects this.
        scale = np.array([img.shape[1], img.shape[0], img.shape[1], img.shape[0]], dtype=img.dtype)
        results['image'] = img
        results['resize'] = np.array(resize, dtype=np.float32)
        results['scale'] = scale
        return results